import java.util.Scanner;

/**
 * Reads four integers {@code n m a b} from standard input and prints the value
 * computed by {@link #maxValue(long, long, long, long)}.
 */
public class One {
    public static void main(String[] args) {
        // try-with-resources so the Scanner is always closed.
        try (Scanner in = new Scanner(System.in)) {
            // nextLong (not nextInt): the values are held in longs, so accept the full
            // long range instead of throwing InputMismatchException above Integer.MAX_VALUE.
            long n = in.nextLong();
            long m = in.nextLong();
            long a = in.nextLong();
            long b = in.nextLong();
            System.out.println(maxValue(n, m, a, b));
        }
    }

    /**
     * Searches pairs of non-negative integers {@code (x, y)} subject to
     * {@code 2x + y <= n} and {@code x + 2y <= m}. For every {@code x} in
     * {@code [0, min(n / 2, m)]} it takes the largest {@code y} allowed by both
     * constraints and evaluates {@code a*x + b*y}.
     *
     * <p>When {@code a} and {@code b} are non-negative this is the maximum of
     * {@code a*x + b*y} over all such pairs. The result is never below 0; 0 is also
     * returned when the range of {@code x} is empty (e.g. negative {@code n} or {@code m}).
     * Runs in time linear in {@code min(n / 2, m)}.
     *
     * @param n budget consumed at rate 2 by {@code x} and 1 by {@code y}
     * @param m budget consumed at rate 1 by {@code x} and 2 by {@code y}
     * @param a weight of {@code x}
     * @param b weight of {@code y}
     * @return the largest {@code a*x + b*y} found, or 0 if none is larger
     */
    static long maxValue(long n, long m, long a, long b) {
        // x needs 2x <= n and x <= m so that both remaining budgets stay non-negative.
        long maxX = Math.min(n / 2, m);
        long best = 0;
        for (long x = 0; x <= maxX; x++) {
            // Largest y still satisfying 2x + y <= n and x + 2y <= m.
            long y = Math.min(n - 2 * x, (m - x) / 2);
            best = Math.max(best, a * x + b * y);
        }
        return best;
    }
}
